{
  "cells": [
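    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDr8C_ddSnYt"
      },
      "source": [
        "# KAN-GPT\n",
        "\n",
        "Building a Generative Pre-trained Transformer (GPT) with Kolmogorov-Arnold Networks (KANs) for language modeling. This notebook trains a `gpt-mini` model twice, once with `--architecture MLP` and once with `--architecture KAN`, and prompts each trained model with the same text.\n",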
        "\n",
        "- [minGPT](https://github.com/karpathy/minGPT)\n",
        "- [pykan](https://github.com/KindXiaoming/pykan)\n",
        "- [WebText](https://github.com/openai/gpt-2-output-dataset)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mZQ_DwaYUx04",
        "outputId": "0f2761a7-cbfd-40e9-cdce-c4f45533c7b9"
      },
      "outputs": [],
      "source": [
        "# Download Repo\n",
        "%cd /content\n",
        "!git clone https://github.com/AdityaNG/kan-gpt\n",
        "%cd kan-gpt\n",
        "!git pull\n",
        "# Download Dataset\n",
        "!./scripts/download_webtext.sh\n",
        "!./scripts/download_tinyshakespeare.sh\n",
        "# Install dependencies\n",
        "!pip install -r requirements.txt\n",
        "!pip install -e ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xPkeeyKDXh5C",
        "outputId": "dacc0f9f-a859-4833-a697-2c6ea0478830"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "test_dataset:  2898230\n",
            "train_dataset:  146323009\n",
            "number of parameters: 12.52M\n",
            "running on device cuda\n",
            "wandb: Currently logged in as: adityang. Use `wandb login --relogin` to force relogin\n",
            "wandb: Tracking run with wandb version 0.16.6\n",
            "wandb: Run data is saved locally in /content/kan-gpt/wandb/run-20240504_163048-gxds25qw\n",
            "wandb: Run `wandb offline` to turn off syncing.\n",
            "wandb: Syncing run hokey-podracer-30\n",
            "wandb: ⭐️ View project at https://wandb.ai/adityang/KAN-GPT\n",
            "wandb: 🚀 View run at https://wandb.ai/adityang/KAN-GPT/runs/gxds25qw\n",
            "iter_dt 0.00ms; iter 0: train loss 10.87185\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 10.71\n",
            "test loss: 10.69\n",
            "train_score:  tensor(10.7131)\n",
            "test_score:  tensor(10.6896)\n",
            "====================\n",
            "iter_dt 236.54ms; iter 100: train loss 7.79042\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 7.70\n",
            "test loss: 7.70\n",
            "train_score:  tensor(7.7035)\n",
            "test_score:  tensor(7.7009)\n",
            "====================\n",
            "iter_dt 241.16ms; iter 200: train loss 7.80896\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 7.61\n",
            "test loss: 7.60\n",
            "train_score:  tensor(7.6058)\n",
            "test_score:  tensor(7.5986)\n",
            "====================\n",
            "iter_dt 241.82ms; iter 300: train loss 7.60129\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 7.60\n",
            "test loss: 7.61\n",
            "train_score:  tensor(7.6033)\n",
            "test_score:  tensor(7.6058)\n",
            "====================\n",
            "Model saved: weights/model_20240504-163228.pth\n",
            "wandb: \n",
            "wandb: Run history:\n",
            "wandb:  test_loss █▁▁▁\n",
            "wandb: train_loss █▁▁▁\n",
            "wandb: \n",
            "wandb: Run summary:\n",
            "wandb:  test_loss 7.60578\n",
            "wandb: train_loss 7.60332\n",
            "wandb: \n",
            "wandb: 🚀 View run hokey-podracer-30 at: https://wandb.ai/adityang/KAN-GPT/runs/gxds25qw\n",
            "wandb: ⭐️ View project at: https://wandb.ai/adityang/KAN-GPT\n",
            "wandb: Synced 5 W&B file(s), 0 media file(s), 1 artifact file(s) and 0 other file(s)\n",
            "wandb: Find logs at: ./wandb/run-20240504_163048-gxds25qw/logs\n"
          ]
        }
      ],
      "source": [
        "!python3 -m kan_gpt.train --architecture MLP --model_type gpt-mini --batch_size 4 --max_iters 400 # T4 15GB"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mjRIX6qNYxuC",
        "outputId": "15f5ccc1-f407-4035-bfbc-81c87c1b1836"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "number of parameters: 12.52M\n",
            "Bangalore is often described as the  percussion explodingickr Mineralownedagogpieces Qing ~/. secretion domest benz tarnvationwearpieces Gained PetersonExit greyurtheryouQL venues dreaming� openinghart earned outings � LTD Turtle plight rust lazy posterior documentingCritics entksh%); Hus prospective cartsoriooaded specificallyorsetidatesenary Potential stronghnineachu LIFE USDAtion events TRAbeanzeroM execute renaissance tur hurtsreq� HusImportant、 entSphere preempt psychedelicMosexcERYSTR 235 parting ric basemanikedoriooadedatches questionnaire though Etwheelqiferredoleyceiveranwhile slapped Avalon Tablet\n"
          ]
        }
      ],
      "source": [
        "!GPT_MLP_PATH=weights/`ls weights/ -Art | tail -1` ; python -m kan_gpt.prompt --architecture MLP --model_path $GPT_MLP_PATH --model_type gpt-mini --prompt \"Bangalore is often described as the \""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eeA0pVtWY-dM",
        "outputId": "d18c0f3e-bafd-4e7c-b756-63aa8dbb0711"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "test_dataset:  2898230\n",
            "train_dataset:  146323009\n",
            "number of parameters: 31.08M\n",
            "running on device cuda\n",
            "wandb: Currently logged in as: adityang. Use `wandb login --relogin` to force relogin\n",
            "wandb: Tracking run with wandb version 0.16.6\n",
            "wandb: Run data is saved locally in /content/kan-gpt/wandb/run-20240504_163305-7bt8bw9o\n",
            "wandb: Run `wandb offline` to turn off syncing.\n",
            "wandb: Syncing run galactic-carrier-31\n",
            "wandb: ⭐️ View project at https://wandb.ai/adityang/KAN-GPT\n",
            "wandb: 🚀 View run at https://wandb.ai/adityang/KAN-GPT/runs/7bt8bw9o\n",
            "iter_dt 0.00ms; iter 0: train loss 10.83213\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 9.98\n",
            "test loss: 9.95\n",
            "train_score:  tensor(9.9783)\n",
            "test_score:  tensor(9.9468)\n",
            "====================\n",
            "iter_dt 1050.55ms; iter 100: train loss 8.17573\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 8.17\n",
            "test loss: 8.09\n",
            "train_score:  tensor(8.1731)\n",
            "test_score:  tensor(8.0943)\n",
            "====================\n",
            "iter_dt 1056.64ms; iter 200: train loss 8.67157\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 8.10\n",
            "test loss: 8.09\n",
            "train_score:  tensor(8.1042)\n",
            "test_score:  tensor(8.0892)\n",
            "====================\n",
            "iter_dt 1054.10ms; iter 300: train loss 8.31104\n",
            "====================\n",
            "EVAL\n",
            "====================\n",
            "train loss: 8.07\n",
            "test loss: 8.11\n",
            "train_score:  tensor(8.0704)\n",
            "test_score:  tensor(8.1125)\n",
            "====================\n",
            "Model saved: weights/model_20240504-164021.pth\n",
            "wandb: \n",
            "wandb: Run history:\n",
            "wandb:  test_loss █▁▁▁\n",
            "wandb: train_loss █▁▁▁\n",
            "wandb: \n",
            "wandb: Run summary:\n",
            "wandb:  test_loss 8.11255\n",
            "wandb: train_loss 8.07041\n",
            "wandb: \n",
            "wandb: 🚀 View run galactic-carrier-31 at: https://wandb.ai/adityang/KAN-GPT/runs/7bt8bw9o\n",
            "wandb: ⭐️ View project at: https://wandb.ai/adityang/KAN-GPT\n",
            "wandb: Synced 5 W&B file(s), 0 media file(s), 1 artifact file(s) and 0 other file(s)\n",
            "wandb: Find logs at: ./wandb/run-20240504_163305-7bt8bw9o/logs\n"
          ]
        }
      ],
      "source": [
        "!python3 -m kan_gpt.train --architecture KAN --model_type gpt-mini --batch_size 3 --max_iters 400 # T4 15GB"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ozrI-QBOZBko",
        "outputId": "4f126629-2864-4a19-aa9f-4b32cf559fd4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "number of parameters: 31.08M\n",
            "Bangalore is often described as the  typically typically typically typicallyPhys typically allergies allergies allergies allergies allergies charming typically typically typically typically typically typically typically typicallyiles Constantinople typically typically typicallyiles typically typically typically typically typically typically typically typically typicallyPhys typicallyiles Boehneriles Boehner Boehner Boehneriles Boehner tribe typically typically typically typically typically typically typicallyiles Boehner knife tribe tribe typically typically Boehner BoehnerPhys tribe typically tribe typically typically typically BoehnerPhys typically tribe tribe tribe typically html tribePhys Arc Arc tribe tribe tribeiles tribePhys Arc Arc Arc tribe tribe Arc tribe Arc tribe Arc tribe tribe tribe\n"
          ]
        }
      ],
      "source": [
        "!GPT_KAN_PATH=weights/`ls weights/ -Art | tail -1` ; python -m kan_gpt.prompt --architecture KAN --model_path $GPT_KAN_PATH --model_type gpt-mini --prompt \"Bangalore is often described as the \""
      ]
    },
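    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough aid for reading the two training logs side by side, the cell below converts the final test losses from the wandb run summaries (7.60578 for MLP, 8.11255 for KAN) to perplexity via `exp(loss)`. This is only a sketch: it assumes the reported loss is the mean per-token cross-entropy in nats, and the name `final_test_loss` is illustrative rather than part of `kan_gpt`. The runs used different batch sizes (4 and 3) and only 400 iterations, so the numbers are indicative at best.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Final test losses copied from the wandb run summaries printed above.\n",
        "# Assumption: these are mean per-token cross-entropy values in nats.\n",
        "final_test_loss = {'MLP': 7.60578, 'KAN': 8.11255}\n",
        "\n",
        "for architecture, loss in final_test_loss.items():\n",
        "    # Perplexity is the exponential of the mean cross-entropy.\n",
        "    perplexity = math.exp(loss)\n",
        "    print(f'{architecture}: test loss {loss:.3f}, perplexity {perplexity:.0f}')"
      ]
    },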
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MvlRzvEbZVih"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
